package com.only4play.common.utils;

import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import org.apache.commons.io.IOUtils;

import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;
import java.util.UUID;

/**
 * Static helpers for generating JWKs (RSA, EC and octet-sequence keys), for RSA
 * encryption/decryption (single-block and chunked), and for converting RSA keys
 * to and from Base64 strings.
 *
 * @author liyuncong
 * @version 1.0
 * @file JwksUtils
 * @brief JwksUtils
 * @details JwksUtils
 * @date 2023-12-12
 *
 * Edit History
 * ----------------------------------------------------------------------------
 * DATE                     NAME               DESCRIPTION
 * 2023-12-12               liyuncong          Created
 */
public class JwksUtils {

    /** Key algorithm name, used for {@link KeyFactory} lookups. */
    private static final String RSA_ALGORITHM = "RSA";
    /**
     * Full cipher transformation. Spelled out explicitly because a bare "RSA"
     * leaves mode and padding to the provider's default, while the chunk size
     * computed in {@code splitCodec} assumes PKCS#1 v1.5 padding.
     */
    private static final String RSA_TRANSFORMATION = "RSA/ECB/PKCS1Padding";
    /** Bytes of every RSA block consumed by PKCS#1 v1.5 encryption padding. */
    private static final int PKCS1_PADDING_OVERHEAD = 11;
    // The three constants below are not referenced anywhere in this class.
    private static final String EC_ALGORITHM = "EC";
    private static final String HMAC_SHA256_ALGORITHM = "HmacSha256";
    private static final int KEY_SIZE = 2048;

    /** Utility class; not instantiable. */
    private JwksUtils() {
    }

    /**
     * Generates an RSA JWK.
     *
     * @return an {@link RSAKey} holding both halves of the key pair obtained from
     *         {@code KeyGeneratorUtils.generateRsaKey()}, with a random UUID as key ID
     */
    public static RSAKey generateRsa() {
        KeyPair keyPair = KeyGeneratorUtils.generateRsaKey();
        RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
        RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();
        return new RSAKey.Builder(publicKey)
            .privateKey(privateKey)
            .keyID(UUID.randomUUID().toString())
            .build();
    }


    /**
     * Encrypts data with a public key in a single cipher operation (no chunking),
     * so {@code data} must fit into one RSA block.
     *
     * @param publicKey public key
     * @param data      plaintext bytes
     * @return the ciphertext
     * @throws IllegalStateException wrapping any failure raised while encrypting
     */
    public static byte[] encrypt(PublicKey publicKey, byte[] data) {
        try {
            Cipher cipher = Cipher.getInstance(RSA_TRANSFORMATION);
            cipher.init(Cipher.ENCRYPT_MODE, publicKey);
            return cipher.doFinal(data);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Encrypts data of arbitrary length with an RSA public key by splitting it
     * into blocks and concatenating the encrypted blocks.
     *
     * @param publicKey RSA public key
     * @param data      plaintext bytes
     * @return the concatenated ciphertext blocks
     * @throws IllegalStateException wrapping any failure raised while encrypting
     */
    public static byte[] encrypt(RSAPublicKey publicKey, byte[] data) {
        try {
            Cipher cipher = Cipher.getInstance(RSA_TRANSFORMATION);
            cipher.init(Cipher.ENCRYPT_MODE, publicKey);
            return splitCodec(cipher, Cipher.ENCRYPT_MODE, data, publicKey.getModulus().bitLength());
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Encrypts a string with an RSA public key. The string is encoded as UTF-8,
     * encrypted block by block, and the result is returned as URL-safe Base64.
     *
     * @param publicKey RSA public key
     * @param data      plaintext
     * @return URL-safe Base64 encoding of the ciphertext
     * @throws IllegalStateException wrapping any failure raised while encrypting
     */
    public static String encrypt(RSAPublicKey publicKey, String data) {
        try {
            Cipher cipher = Cipher.getInstance(RSA_TRANSFORMATION);
            cipher.init(Cipher.ENCRYPT_MODE, publicKey);
            byte[] plain = data.getBytes(StandardCharsets.UTF_8);
            byte[] encrypted = splitCodec(cipher, Cipher.ENCRYPT_MODE, plain,
                publicKey.getModulus().bitLength());
            return Base64.getUrlEncoder().encodeToString(encrypted);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Decrypts data with a private key in a single cipher operation (no chunking).
     *
     * @param privateKey private key
     * @param encrypted  ciphertext consisting of one RSA block
     * @return the plaintext bytes
     * @throws IllegalStateException wrapping any failure raised while decrypting
     */
    public static byte[] decrypt(PrivateKey privateKey, byte[] encrypted) {
        try {
            Cipher cipher = Cipher.getInstance(RSA_TRANSFORMATION);
            cipher.init(Cipher.DECRYPT_MODE, privateKey);
            return cipher.doFinal(encrypted);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Decrypts ciphertext produced by the chunked
     * {@link #encrypt(RSAPublicKey, byte[])} with the matching RSA private key.
     *
     * @param privateKey RSA private key
     * @param encrypted  concatenated ciphertext blocks
     * @return the plaintext bytes
     * @throws IllegalStateException wrapping any failure raised while decrypting
     */
    public static byte[] decrypt(RSAPrivateKey privateKey, byte[] encrypted) {
        try {
            Cipher cipher = Cipher.getInstance(RSA_TRANSFORMATION);
            cipher.init(Cipher.DECRYPT_MODE, privateKey);
            return splitCodec(cipher, Cipher.DECRYPT_MODE, encrypted, privateKey.getModulus().bitLength());
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Decrypts a string produced by {@link #encrypt(RSAPublicKey, String)}.
     *
     * @param privateKey RSA private key
     * @param data       URL-safe Base64 encoding of the ciphertext
     * @return the plaintext, decoded as UTF-8
     * @throws IllegalStateException wrapping any failure raised while decoding or decrypting
     */
    public static String decrypt(RSAPrivateKey privateKey, String data) {
        try {
            Cipher cipher = Cipher.getInstance(RSA_TRANSFORMATION);
            cipher.init(Cipher.DECRYPT_MODE, privateKey);
            byte[] encrypted = Base64.getUrlDecoder().decode(data);
            byte[] plain = splitCodec(cipher, Cipher.DECRYPT_MODE, encrypted,
                privateKey.getModulus().bitLength());
            // Decode with the charset used on the encrypt side instead of the
            // platform default, so non-ASCII text survives the round trip.
            return new String(plain, StandardCharsets.UTF_8);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Runs an initialized cipher over the input one block at a time and
     * concatenates the per-block results.
     *
     * @param cipher  cipher already initialized for {@code opmode}
     * @param opmode  {@link Cipher#DECRYPT_MODE} to decrypt; any other value is treated as encryption
     * @param datas   input bytes
     * @param keySize bit length of the RSA modulus
     * @return the concatenated output of every block; empty when {@code datas} is empty
     * @throws RuntimeException wrapping any failure raised while processing a block
     */
    private static byte[] splitCodec(Cipher cipher, int opmode, byte[] datas, int keySize) {
        // An RSA block is as long as the modulus in bytes; round up so a modulus
        // whose bit length is not a multiple of 8 still gets the right size.
        int modulusBytes = (keySize + 7) / 8;
        // Ciphertext blocks are full-size; plaintext blocks must leave room for padding.
        int maxBlock = opmode == Cipher.DECRYPT_MODE
            ? modulusBytes
            : modulusBytes - PKCS1_PADDING_OVERHEAD;
        try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            for (int offset = 0; offset < datas.length; offset += maxBlock) {
                int length = Math.min(maxBlock, datas.length - offset);
                byte[] chunk = cipher.doFinal(datas, offset, length);
                out.write(chunk, 0, chunk.length);
            }
            return out.toByteArray();
        } catch (Exception e) {
            throw new RuntimeException("加解密阀值为[" + maxBlock + "]的数据时发生异常", e);
        }
    }

    /**
     * Serializes a public key to Base64.
     *
     * @param publicKey public key
     * @return standard Base64 encoding of {@code publicKey.getEncoded()}
     */
    public static String publicKeyToString(PublicKey publicKey) {
        return Base64.getEncoder().encodeToString(publicKey.getEncoded());
    }

    /**
     * Restores an RSA public key from its Base64 string form.
     *
     * @param publicKeyString standard Base64 encoding of an X.509-encoded RSA public key
     * @return the public key
     * @throws RuntimeException wrapping any decoding or key-spec failure
     */
    public static PublicKey publicKeyFromString(String publicKeyString) {
        try {
            byte[] keyBytes = Base64.getDecoder().decode(publicKeyString);
            X509EncodedKeySpec keySpec = new X509EncodedKeySpec(keyBytes);
            return KeyFactory.getInstance(RSA_ALGORITHM).generatePublic(keySpec);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Serializes a private key to Base64.
     *
     * @param privateKey private key
     * @return standard Base64 encoding of {@code privateKey.getEncoded()}
     */
    public static String privateKeyToString(PrivateKey privateKey) {
        return Base64.getEncoder().encodeToString(privateKey.getEncoded());
    }

    /**
     * Restores an RSA private key from its Base64 string form.
     *
     * @param privateKeyString standard Base64 encoding of a PKCS#8-encoded RSA private key
     * @return the private key
     * @throws RuntimeException wrapping any decoding or key-spec failure
     */
    public static PrivateKey privateKeyFromString(String privateKeyString) {
        try {
            byte[] keyBytes = Base64.getDecoder().decode(privateKeyString);
            PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyBytes);
            return KeyFactory.getInstance(RSA_ALGORITHM).generatePrivate(keySpec);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }


    /**
     * Generates an EC JWK.
     *
     * @return an {@link ECKey} holding both halves of the key pair obtained from
     *         {@code KeyGeneratorUtils.generateEcKey()}, with a random UUID as key ID
     */
    public static ECKey generateEc() {
        KeyPair keyPair = KeyGeneratorUtils.generateEcKey();
        ECPublicKey publicKey = (ECPublicKey) keyPair.getPublic();
        ECPrivateKey privateKey = (ECPrivateKey) keyPair.getPrivate();
        // The JWK curve is derived from the public key's parameters.
        Curve curve = Curve.forECParameterSpec(publicKey.getParams());
        return new ECKey.Builder(curve, publicKey)
            .privateKey(privateKey)
            .keyID(UUID.randomUUID().toString())
            .build();
    }

    /**
     * Generates a symmetric (octet-sequence) JWK.
     *
     * @return an {@link OctetSequenceKey} wrapping the secret obtained from
     *         {@code KeyGeneratorUtils.generateSecretKey()}, with a random UUID as key ID
     */
    public static OctetSequenceKey generateSecret() {
        SecretKey secretKey = KeyGeneratorUtils.generateSecretKey();
        return new OctetSequenceKey.Builder(secretKey)
            .keyID(UUID.randomUUID().toString())
            .build();
    }

    /**
     * Manual smoke test: encrypts a multi-block string with a freshly generated
     * RSA key, decrypts it again and prints both values.
     */
    public static void main(String[] args) throws Exception {
        String data = "Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello," +
            " World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, " +
            "World!Hello, World!Hello, World!Hello, World!Hello, World!Hello, World!";

        RSAKey rsaKey = generateRsa();

        RSAPublicKey publicKey = rsaKey.toRSAPublicKey();
        RSAPrivateKey privateKey = rsaKey.toRSAPrivateKey();

        String str1 = encrypt(publicKey, data);
        String str2 = decrypt(privateKey, str1);

        System.out.println("Encrypted str1: " + str1);
        System.out.println("Decrypted str2: " + str2);

    }
}
